{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d124d22-de73-436b-86cd-9b162b469be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install langchain_community\n",
    "%pip install langchain_experimental \n",
    "%pip install langchain-openai \n",
    "%pip install langchainhub \n",
    "%pip install chromadb \n",
    "%pip install langchain\n",
    "%pip install beautifulsoup4\n",
    "%pip install python-dotenv\n",
    "\n",
    "# new\n",
    "%pip install gradio\n",
    "%pip uninstall uvloop -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f884314f-870c-4bfb-b6c1-a5b4801ec172",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "import bs4\n",
    "import openai\n",
    "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
    "from langchain import hub\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "import chromadb\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain_experimental.text_splitter import SemanticChunker\n",
    "from langchain_core.runnables import RunnableParallel\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "# new\n",
    "import asyncio\n",
    "import nest_asyncio\n",
    "asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())\n",
    "nest_asyncio.apply()\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eba3468a-d7c2-4a79-8df2-c335542950f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# variables\n",
    "_ = load_dotenv(dotenv_path='env.txt')\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY')\n",
    "openai.api_key = os.environ['OPENAI_API_KEY']\n",
    "embedding_function = OpenAIEmbeddings()\n",
    "llm = ChatOpenAI(model_name=\"gpt-4o-mini\")\n",
    "str_output_parser = StrOutputParser()\n",
    "user_query = \"What are the advantages of using RAG?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3ad428a-3eb6-40ec-a1a5-62565ead1e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### INDEXING ####"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ccda2c-0f4c-41c5-804d-2227cdf35aa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Documents\n",
    "loader = WebBaseLoader(\n",
    "    web_paths=(\"https://kbourne.github.io/chapter1.html\",), \n",
    "    bs_kwargs=dict(\n",
    "        parse_only=bs4.SoupStrainer(\n",
    "            class_=(\"post-content\", \"post-title\", \"post-header\")\n",
    "        )\n",
    "    ),\n",
    ")\n",
    "docs = loader.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "927a4c65-aa05-486c-8295-2f99673e7c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split\n",
    "text_splitter = SemanticChunker(embedding_function)\n",
    "splits = text_splitter.split_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b13568c-d633-464d-8c43-0d55f34cc8c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed\n",
    "vectorstore = Chroma.from_documents(documents=splits, \n",
    "                                    embedding=embedding_function)\n",
    "\n",
    "retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce8df01-925b-45b5-8fb8-17b5c40c581f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### RETRIEVAL and GENERATION ####"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fac053d8-b871-4b50-b04e-28dec9fb3b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt\n",
    "prompt = hub.pull(\"jclemens24/rag-prompt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ef30632-13dd-4a34-af33-cb8fab94f169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relevance check prompt\n",
    "relevance_prompt_template = PromptTemplate.from_template(\n",
    "    \"\"\"\n",
    "    Given the following question and retrieved context, determine if the context is relevant to the question.\n",
    "    Provide a score from 1 to 5, where 1 is not at all relevant and 5 is highly relevant.\n",
    "    Return ONLY the numeric score, without any additional text or explanation.\n",
    "\n",
    "    Question: {question}\n",
    "    Retrieved Context: {retrieved_context}\n",
    "\n",
    "    Relevance Score:\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8975479-b3e3-481d-ad7b-08b4eb3faaef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-processing\n",
    "def format_docs(docs):\n",
    "    return \"\\n\\n\".join(doc.page_content for doc in docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd9db713-f705-4b65-800e-2c4e3d0e4ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_score(llm_output):\n",
    "    try:\n",
    "        score = float(llm_output.strip())\n",
    "        return score\n",
    "    except ValueError:\n",
    "        return 0\n",
    "\n",
    "# Chain it all together with LangChain\n",
    "def conditional_answer(x):\n",
    "    relevance_score = extract_score(x['relevance_score'])\n",
    "    if relevance_score < 4:\n",
    "        return \"I don't know.\"\n",
    "    else:\n",
    "        return x['answer']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d4ffa9-da17-48a3-bc83-196e21cdf0f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "rag_chain_from_docs = (\n",
    "    RunnablePassthrough.assign(context=(lambda x: format_docs(x[\"context\"])))\n",
    "    | RunnableParallel(\n",
    "        {\"relevance_score\": (\n",
    "            RunnablePassthrough()\n",
    "            | (lambda x: relevance_prompt_template.format(question=x['question'], retrieved_context=x['context']))\n",
    "            | llm\n",
    "            | str_output_parser\n",
    "        ), \"answer\": (\n",
    "            RunnablePassthrough()\n",
    "            | prompt\n",
    "            | llm\n",
    "            | str_output_parser\n",
    "        )}\n",
    "    )\n",
    "    | RunnablePassthrough().assign(final_answer=conditional_answer)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc5c2ab0-9191-40f7-abf2-681f1c751429",
   "metadata": {},
   "outputs": [],
   "source": [
    "rag_chain_with_source = RunnableParallel(\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    ").assign(answer=rag_chain_from_docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b30177a-f9ab-45e4-812d-33b0f97325bd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Question - run the chain\n",
    "result = rag_chain_with_source.invoke(user_query)\n",
    "relevance_score = result['answer']['relevance_score']\n",
    "final_answer = result['answer']['final_answer']\n",
    "\n",
    "print(f\"Relevance Score: {relevance_score}\")\n",
    "print(f\"Final Answer:\\n{final_answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3da0d9-4b61-434b-afae-24678cd25d64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradio Interface\n",
    "def process_question(question):\n",
    "    result = rag_chain_with_source.invoke(question)\n",
    "    relevance_score = result['answer']['relevance_score']\n",
    "    final_answer = result['answer']['final_answer']\n",
    "    sources = [doc.metadata['source'] for doc in result['context']]\n",
    "    source_list = \", \".join(sources)\n",
    "    return relevance_score, final_answer, source_list\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn=process_question,\n",
    "    inputs=gr.Textbox(label=\"Enter your question\", value=user_query),\n",
    "    outputs=[\n",
    "        gr.Textbox(label=\"Relevance Score\"),\n",
    "        gr.Textbox(label=\"Final Answer\"),\n",
    "        gr.Textbox(label=\"Sources\")\n",
    "    ],\n",
    "    title=\"RAG Question Answering\",\n",
    "    description=\"Enter a question and get the relevance score, final answer, and sources from RAG.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eb1e76d-9b12-4938-8b4d-c0184617412b",
   "metadata": {},
   "outputs": [],
   "source": [
    "demo.launch(share=True, debug=True) # without credentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59d3d0e-cb69-4cc5-bbef-db2ee947321b",
   "metadata": {},
   "outputs": [],
   "source": [
    "demo.launch(share=True, debug=True, auth=(\"admin\", \"pass1234\")) # with credentials"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
